

# =============================================================================
# UNIFIED COSMIC MODEL - COMPLETE SPARC DATABASE ANALYSIS (175 GALAXIES)
# =============================================================================

import numpy as np
import pandas as pd
import zipfile
import os
from io import TextIOWrapper
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter
from scipy.interpolate import UnivariateSpline
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# COMPLETE SPARC DATA LOADER (ALL 175 GALAXIES)
# =============================================================================

class CompleteSPARCLoader:
    def __init__(self) -> None:
        """Create a loader with an empty galaxy registry.

        ``load_all_galaxies`` fills ``galaxies_data`` with one entry per
        accepted galaxy, keyed by the cleaned galaxy name.
        """
        self.galaxies_data = dict()

    def load_all_galaxies(self, zip_path="Rotmod_LTG.zip"):
        """Load every ``*_rotmod.dat`` rotation curve found in a zip archive.

        Each matching archive member is parsed with ``process_galaxy_file``.
        Galaxies that come back with at least four data points are stored in
        ``self.galaxies_data`` under their cleaned name. A member that fails
        to parse is reported and skipped; it never aborts the whole load.

        Args:
            zip_path: Path to the archive (defaults to ``Rotmod_LTG.zip`` in
                the current working directory).

        Returns:
            ``self.galaxies_data`` (name -> galaxy dict), or ``None`` when the
            archive does not exist or cannot be opened.
        """
        print("📁 Loading COMPLETE SPARC database (175 galaxies)...")

        if not os.path.exists(zip_path):
            print(f"❌ File {zip_path} not found!")
            print("⚠️ Please download Rotmod_LTG.zip from http://astroweb.cwru.edu/SPARC/")
            return None

        try:
            with zipfile.ZipFile(zip_path, 'r') as zp:
                rotmod_files = [f for f in zp.namelist() if f.endswith('_rotmod.dat')]
                print(f"📊 Found {len(rotmod_files)} total galaxy files")

                successful_loads = 0
                for file_name in rotmod_files:
                    try:
                        galaxy_data = self.process_galaxy_file(zp, file_name)
                        if galaxy_data and len(galaxy_data['r']) >= 4:  # Relaxed from 6 to 4
                            self.galaxies_data[galaxy_data['name']] = galaxy_data
                            successful_loads += 1

                            # Report progress only when the counter actually
                            # advanced; otherwise a rejected file would repeat
                            # the previous milestone (or print "Loaded 0").
                            if successful_loads % 25 == 0:
                                print(f"   ✅ Loaded {successful_loads} galaxies...")

                    except Exception as e:
                        # Best effort: one unreadable member must not stop the load.
                        print(f"   ⚠️ Failed to load {file_name}: {e}")
                        continue

                # Report against the number of files actually present in the
                # archive instead of assuming it always holds 175.
                print(f"🎯 Successfully loaded {successful_loads}/{len(rotmod_files)} galaxies from SPARC database")
                return self.galaxies_data

        except Exception as e:
            print(f"❌ Error loading ZIP: {e}")
            return None

    def process_galaxy_file(self, zp, file_name):
        """Parse one ``*_rotmod.dat`` archive member into a galaxy dict.

        The member is read once as bytes and decoded with the first encoding
        that succeeds, so the fallback list works without re-reading a stream
        that an earlier attempt may already have closed. Only the first six
        whitespace-separated columns are used. Without ``usecols``, a file
        with more columns than ``names`` would have its leading columns turned
        into the index by pandas, silently shifting every field.

        Rows are kept when ``R > 0.05``, ``3 < Vobs < 600`` and
        ``0.05 < eVobs < 150``; rows with missing or non-numeric values are
        dropped.

        Args:
            zp: An open ``zipfile.ZipFile``.
            file_name: Name of the member inside the archive.

        Returns:
            A dict with keys ``name``, ``r``, ``v_obs``, ``v_err``, ``v_gas``,
            ``v_disk``, ``v_bulge``, ``source`` and ``file`` (arrays sorted by
            radius), or ``None`` when fewer than four usable rows remain or
            the file cannot be decoded/parsed.
        """
        from io import StringIO  # local import: the module header only brings in TextIOWrapper

        raw = zp.read(file_name)

        text = None
        for encoding in ('utf-8', 'latin-1', 'cp1252', 'iso-8859-1'):
            try:
                text = raw.decode(encoding)
            except UnicodeDecodeError:
                continue
            break
        if text is None:
            return None

        columns = ["R", "Vobs", "eVobs", "Vgas", "Vdisk", "Vbulge"]
        try:
            df = pd.read_csv(
                StringIO(text),
                sep=r'\s+', comment="#", engine='python',
                names=columns,
                # Pin the six leading columns so trailing extras are ignored
                # rather than pushing the leading columns into the index.
                usecols=list(range(len(columns))),
                na_values=['NaN', 'nan', '-', ''],
                skipinitialspace=True
            )
        except ValueError:
            # Covers pandas ParserError / EmptyDataError (both ValueError
            # subclasses), e.g. an empty member or one with too few columns.
            return None

        # Robust data cleaning with relaxed criteria: stray text becomes NaN
        # and is dropped together with genuinely missing values.
        df = df.apply(pd.to_numeric, errors='coerce').dropna()
        if len(df) < 4:  # Relaxed minimum data points
            return None

        # More lenient criteria for full dataset
        df = df[
            (df['R'] > 0.05) &      # Relaxed from 0.1
            (df['Vobs'] > 3) &      # Relaxed from 5
            (df['Vobs'] < 600) &    # Relaxed from 500
            (df['eVobs'] > 0.05) &  # Relaxed from 0.1
            (df['eVobs'] < 150)     # Relaxed from 100
        ]
        if len(df) < 4:
            return None

        # Sort by radius for consistent processing
        df = df.sort_values('R').reset_index(drop=True)

        return {
            'name': self.clean_galaxy_name(file_name),
            'r': df['R'].values,
            'v_obs': df['Vobs'].values,
            'v_err': df['eVobs'].values,
            'v_gas': df['Vgas'].values,
            'v_disk': df['Vdisk'].values,
            'v_bulge': df['Vbulge'].values,
            'source': 'REAL_SPARC',
            'file': file_name
        }

    def clean_galaxy_name(self, file_name):
        """Return the bare galaxy name for an archive member path.

        Any directory prefix is dropped (not just ``RC/`` / ``rc/``), then a
        trailing ``_rotmod.dat`` or ``.dat`` suffix is removed. Working on
        the final path component avoids corrupting names whose directory
        merely contains the characters ``rc/`` (e.g. ``src/X`` -> ``sX``).

        Args:
            file_name: Member name as listed by the zip archive.

        Returns:
            The galaxy name without directories or file extension.
        """
        # Zip members use '/', but tolerate archives written with backslashes.
        base = file_name.replace('\\', '/').rsplit('/', 1)[-1]
        for suffix in ('_rotmod.dat', '.dat'):
            if base.endswith(suffix):
                return base[:-len(suffix)]
        return base

# =============================================================================
# ADVANCED COSMIC MODEL FOR COMPLETE DATABASE
# =============================================================================

class CompleteDatabaseCosmicModel:
    def __init__(self, galaxy_data):
        """Keep a reference to the loaded galaxies and reset all derived state.

        Args:
            galaxy_data: Mapping of galaxy name -> per-galaxy dict; the
                methods of this class read its ``r``, ``v_obs``, ``v_err``
                and ``v_gas`` entries.
        """
        self.galaxy_data = galaxy_data

        # Filled in later by the manifold / analysis steps.
        self.results = list()
        self.manifold_coords = None
        self.galaxy_clusters = None

        # Labels for the entries of the per-galaxy feature vector, in order.
        self.feature_names = (
            'log_max_velocity log_max_radius flatness_ratio '
            'data_points_count variability initial_slope '
            'gas_contribution asymmetry median_ratio'
        ).split()

    def build_comprehensive_manifold(self):
        """Embed all galaxies with t-SNE and assign each one a cluster label.

        One nine-element feature vector is computed per galaxy, in the order
        given by ``self.feature_names``. Galaxies whose features raise or
        contain NaN/inf are skipped. The remaining vectors are standardised,
        embedded into three dimensions with t-SNE and clustered with
        ``bayesian_cluster_galaxies``.

        Side effects:
            * ``self.manifold_coords``: embedding, one row per *valid* galaxy.
            * ``self.manifold_galaxies``: names matching those rows.
            * ``self.galaxy_clusters``: one label per entry of
              ``self.galaxy_data`` in iteration order, with ``-1`` for
              galaxies that had no usable features. Callers index this array
              by position in ``galaxy_data``, so a skipped galaxy must not
              shift the labels of those after it.

        Returns:
            The embedding array, or ``None`` when fewer than 10 galaxies
            yield valid features (no attribute is modified in that case).
        """
        print("🔭 Building Comprehensive Cosmic Manifold for 175 galaxies...")

        features = []
        valid_galaxies = []
        valid_positions = []  # index of each valid galaxy within galaxy_data

        for position, (name, data) in enumerate(self.galaxy_data.items()):
            try:
                # Ensure data is sorted by radius
                sort_idx = np.argsort(data['r'])
                r_sorted = data['r'][sort_idx]
                v_obs_sorted = data['v_obs'][sort_idx]
                v_gas_sorted = data['v_gas'][sort_idx]

                # Comprehensive feature set (order matches self.feature_names)
                feature_vector = [
                    np.log10(np.max(v_obs_sorted)),
                    np.log10(np.max(r_sorted)),
                    np.mean(v_obs_sorted) / np.max(v_obs_sorted),
                    len(v_obs_sorted),
                    np.std(v_obs_sorted) / np.mean(v_obs_sorted),
                    self.calculate_curve_slope(v_obs_sorted, r_sorted),
                    np.mean(v_gas_sorted) / np.max(v_obs_sorted),
                    self.calculate_asymmetry(v_obs_sorted, r_sorted),
                    np.median(v_obs_sorted) / np.max(v_obs_sorted),
                ]

                # Keep only vectors free of NaN and +/-inf
                if np.all(np.isfinite(feature_vector)):
                    features.append(feature_vector)
                    valid_galaxies.append(name)
                    valid_positions.append(position)

            except Exception as e:
                print(f"   ⚠️ Feature extraction failed for {name}: {e}")
                continue

        print(f"📊 Extracted features from {len(features)} galaxies")

        if len(features) < 10:
            print("⚠️ Not enough galaxies for comprehensive manifold")
            return None

        features = np.array(features)

        # Advanced normalization
        scaler = StandardScaler()
        features_scaled = scaler.fit_transform(features)

        # Build high-quality manifold. The iteration count is left at the
        # library default (1000): the keyword for it was renamed from
        # ``n_iter`` to ``max_iter`` in recent scikit-learn releases, so
        # passing either name breaks on some installed version.
        self.manifold_coords = TSNE(
            n_components=3,
            random_state=42,
            perplexity=min(30, len(features_scaled)//3),
            learning_rate=200
        ).fit_transform(features_scaled)
        self.manifold_galaxies = valid_galaxies

        # Cluster the valid galaxies, then spread the labels back over the
        # full galaxy list so label i always belongs to galaxy i.
        labels = np.asarray(self.bayesian_cluster_galaxies(features_scaled))
        aligned = np.full(len(self.galaxy_data), -1, dtype=labels.dtype)
        aligned[valid_positions] = labels
        self.galaxy_clusters = aligned

        print(f"✅ Comprehensive manifold built with {len(valid_galaxies)} galaxies")
        print(f"   Discovered {len(np.unique(labels))} natural galaxy clusters")

        return self.manifold_coords

    def bayesian_cluster_galaxies(self, features):
        """Cluster feature vectors with a BIC-selected Gaussian mixture.

        Despite the method name, the models fitted are ordinary
        ``GaussianMixture`` instances; "Bayesian" refers only to choosing the
        component count by the Bayesian information criterion (lowest BIC
        wins). Candidate counts are ``range(3, min(8, len(features)//10))``,
        which is empty for fewer than 40 samples. If no mixture can be
        fitted, K-means is used instead.

        Args:
            features: 2-D array-like of standardised features, one row per
                galaxy.

        Returns:
            Array of integer cluster labels, one per row of ``features``.
        """
        best_bic = np.inf
        best_gmm = None

        for n_components in range(3, min(8, len(features)//10)):
            try:
                gmm = GaussianMixture(
                    n_components=n_components,
                    covariance_type='full',
                    random_state=42,
                    max_iter=200
                )
                gmm.fit(features)
                bic = gmm.bic(features)
            except Exception as e:
                # Best effort: report the failed candidate and try the next
                # size instead of hiding the failure.
                print(f"   ⚠️ GMM with {n_components} components failed: {e}")
                continue

            if bic < best_bic:
                best_bic = bic
                best_gmm = gmm

        if best_gmm is not None:
            clusters = best_gmm.predict(features)
            print(f"   Bayesian GMM selected {len(np.unique(clusters))} clusters with BIC: {best_bic:.2f}")
            return clusters

        # Fallback to K-means; never ask for more clusters than samples.
        kmeans = KMeans(n_clusters=min(5, len(features)), random_state=42, n_init=10)
        clusters = kmeans.fit_predict(features)
        print(f"   Used K-means with {len(np.unique(clusters))} clusters")
        return clusters

    def calculate_curve_slope(self, v_obs, r):
        """Return the normalised slope of the first part of a rotation curve.

        A straight line is fitted to the first ``max(2, len(r) // 3)``
        samples and its slope is divided by the mean of all velocities.

        Args:
            v_obs: Velocities, ordered like ``r``.
            r: Radii (the caller passes them sorted ascending).

        Returns:
            The normalised slope, or ``0.0`` when fewer than three points are
            given or the line fit fails.
        """
        if len(r) < 3:
            return 0.0

        n_points = max(2, len(r) // 3)
        try:
            slope = np.polyfit(r[:n_points], v_obs[:n_points], 1)[0]
        except (np.linalg.LinAlgError, TypeError, ValueError):
            # Catch only what polyfit can raise; a bare except would also
            # swallow KeyboardInterrupt/SystemExit.
            return 0.0
        return slope / np.mean(v_obs)

    def calculate_asymmetry(self, v_obs, r):
        """Return the relative difference between the two halves of a curve.

        The velocity samples are split at the midpoint (first half vs second
        half, i.e. inner vs outer when sorted by radius). The result is
        ``|mean(first) - mean(second)| / mean(all)``, capped at 1.0.

        Args:
            v_obs: Velocities, ordered like ``r``.
            r: Radii; only its length is used.

        Returns:
            The capped ratio, or ``0.0`` when fewer than four radii are
            given, a half is empty, or the values are not numeric.
        """
        if len(r) < 4:
            return 0.0

        try:
            mid_point = len(v_obs) // 2
            first_half = v_obs[:mid_point]
            second_half = v_obs[mid_point:]

            # Only reachable when v_obs is shorter than r.
            if len(first_half) == 0 or len(second_half) == 0:
                return 0.0

            asymmetry = abs(np.mean(first_half) - np.mean(second_half)) / np.mean(v_obs)
        except (TypeError, ValueError):
            # Non-numeric input; a bare except would also hide interrupts.
            return 0.0
        return min(asymmetry, 1.0)

    def advanced_rotation_model(self, r, v_max, r_scale, alpha, beta, gamma):
        """Evaluate the five-parameter rotation-curve model at radius ``r``.

        ``v(r) = v_max * (1 - exp(-(r/r_scale)**alpha))
                 * (1 + beta * exp(-r/r_scale))
                 * (1 + gamma * r / (r_scale + r))``

        Args:
            r: Radius or array of radii.
            v_max: Amplitude of the base profile.
            r_scale: Scale radius.
            alpha: Exponent of the base profile.
            beta: Weight of the exponentially decaying modulation.
            gamma: Weight of the saturating modulation.

        Returns:
            Model velocities with the shape of ``r``. If the full expression
            raises an arithmetic or type error (e.g. overflow with plain
            Python floats), the plain ``v_max * (1 - exp(-r/r_scale))``
            profile is returned instead.
        """
        try:
            base_profile = v_max * (1 - np.exp(-(r/r_scale)**alpha))
            modulation = (1 + beta * np.exp(-r/r_scale)) * (1 + gamma * r / (r_scale + r))
            return base_profile * modulation
        except (ArithmeticError, TypeError, ValueError):
            # Narrowed from a bare except so interrupts are not swallowed.
            return v_max * (1 - np.exp(-r/r_scale))

    def calculate_robust_chi2(self, v_obs, r, v_err):
        """Fit a rotation-curve model and return its reduced χ².

        The data are sorted by radius and every error is floored at 10% of
        the mean error. Three models are tried in order; the first that
        succeeds is used:

        1. ``advanced_rotation_model`` (5 parameters, bounded fit),
        2. ``v_max * (1 - exp(-r / r_scale))`` (2 parameters),
        3. a constant at the median velocity (1 parameter).

        Args:
            v_obs: Velocities.
            r: Radii, same length as ``v_obs``.
            v_err: Velocity uncertainties, same length as ``v_obs``.

        Returns:
            ``(chi2_reduced, v_model)``: χ² divided by ``max(dof, 1)`` and the
            model velocities in radius-sorted order. If even the preparation
            step fails, a constant model at the mean velocity is returned in
            the input order.
        """
        try:
            sort_indices = np.argsort(r)
            r_sorted = r[sort_indices]
            v_obs_sorted = v_obs[sort_indices]
            v_err_sorted = v_err[sort_indices]

            # Floor tiny errors so single points cannot dominate χ².
            v_err_sorted = np.maximum(v_err_sorted, 0.1 * np.mean(v_err_sorted))

            # Simple flat model as fallback
            v_flat = np.ones_like(v_obs_sorted) * np.median(v_obs_sorted)
            chi2_flat = np.sum(((v_obs_sorted - v_flat) / v_err_sorted)**2)
            dof_flat = max(len(v_obs_sorted) - 1, 1)

            try:
                v_max_guess = np.max(v_obs_sorted)
                r_scale_guess = np.percentile(r_sorted, 60)

                lower = [0.5*v_max_guess, 0.3, 0.2, -0.5, -0.2]
                upper = [2*v_max_guess, 100, 1.5, 0.5, 0.2]
                # curve_fit rejects a start point outside the bounds, which
                # used to skip this model for very compact or very extended
                # curves; pull the guess back into the feasible box.
                p0 = np.clip([v_max_guess, r_scale_guess, 0.8, 0.1, 0.05], lower, upper)

                popt, pcov = curve_fit(
                    self.advanced_rotation_model, r_sorted, v_obs_sorted,
                    p0=p0,
                    sigma=v_err_sorted,
                    maxfev=5000,
                    bounds=(lower, upper)
                )

                v_model = self.advanced_rotation_model(r_sorted, *popt)
                chi2 = np.sum(((v_obs_sorted - v_model) / v_err_sorted)**2)
                dof = len(v_obs_sorted) - 5

                return chi2 / max(dof, 1), v_model

            except Exception:
                # Any fit failure (non-convergence, invalid bounds, NaNs)
                # drops down to the simpler model.
                try:
                    def simple_model(r, v_max, r_scale):
                        return v_max * (1 - np.exp(-r/r_scale))

                    v_max_guess = np.max(v_obs_sorted)
                    r_scale_guess = np.median(r_sorted)

                    # Weight by the same errors that enter χ² below, so the
                    # fit minimises the quantity that is reported.
                    popt, _ = curve_fit(simple_model, r_sorted, v_obs_sorted,
                                        p0=[v_max_guess, r_scale_guess],
                                        sigma=v_err_sorted)
                    v_model = simple_model(r_sorted, *popt)
                    chi2 = np.sum(((v_obs_sorted - v_model) / v_err_sorted)**2)
                    return chi2 / max(len(v_obs_sorted)-2, 1), v_model
                except Exception:
                    return chi2_flat / dof_flat, v_flat

        except Exception:
            # Last resort when the inputs could not even be sorted/indexed.
            v_flat = np.ones_like(v_obs) * np.mean(v_obs)
            chi2_simple = np.sum(((v_obs - v_flat) / np.mean(v_err))**2)
            return chi2_simple / max(len(v_obs)-1, 1), v_flat

    def intelligent_correction(self, v_obs, r, v_err, cluster_id):
        """Smooth a rotation curve with the routine assigned to its cluster.

        The samples are ordered by radius and handed to one smoother picked
        by ``cluster_id``: 0 -> ``correct_flat_galaxy``,
        1 -> ``correct_rising_galaxy``, 2 -> ``correct_complex_galaxy``,
        anything else -> ``correct_normal_galaxy``. The smoothed values are
        then limited to between 0.5x and 2x the observed velocity at each
        point. ``v_err`` is accepted but not used.

        Returns:
            The limited, smoothed velocities in radius order, or the
            untouched ``v_obs`` argument if any step raises.
        """
        try:
            order = np.argsort(r)
            radii = r[order]
            velocities = v_obs[order]

            # Default handler, replaced when the label matches a known cluster.
            smoother = self.correct_normal_galaxy
            for label, candidate in (
                (0, self.correct_flat_galaxy),
                (1, self.correct_rising_galaxy),
                (2, self.correct_complex_galaxy),
            ):
                if cluster_id == label:
                    smoother = candidate
                    break

            smoothed = smoother(velocities, radii)
            lower_limit = 0.5 * velocities
            upper_limit = 2.0 * velocities
            return np.clip(smoothed, lower_limit, upper_limit)

        except Exception:
            return v_obs

    def correct_flat_galaxy(self, v_obs, r):
        """Smooth a curve with a quadratic Savitzky-Golay filter.

        The window is the largest odd length that is at most 11 and does not
        exceed the number of samples. The window used to be rounded *up* to
        an odd number, so curves with 6, 8 or 10 points got a window longer
        than the data, the filter raised, and the curve came back unsmoothed.

        Args:
            v_obs: Velocities ordered by radius.
            r: Radii (unused; kept for a uniform smoother signature).

        Returns:
            The smoothed velocities, or ``v_obs`` itself when fewer than five
            points are given or the filter rejects the input.
        """
        n_points = len(v_obs)
        if n_points < 5:
            return v_obs

        largest_odd = n_points if n_points % 2 else n_points - 1
        window_length = min(11, largest_odd)  # always >= 5 here
        try:
            return savgol_filter(v_obs, window_length, 2)
        except (ValueError, TypeError, np.linalg.LinAlgError):
            return v_obs

    def correct_rising_galaxy(self, v_obs, r):
        """Smooth a curve with a smoothing spline over the distinct radii.

        Samples whose radius does not exceed the previous one by more than
        ``1e-10`` are left out of the fit (this assumes ``r`` is sorted
        ascending, as the caller provides). The spline uses the smoothing
        factor ``s = 5 * number_of_fitted_points`` and is evaluated at every
        original radius.

        Args:
            v_obs: Velocities ordered by radius.
            r: Radii, sorted ascending.

        Returns:
            The spline values at ``r``, or ``v_obs`` itself when fewer than
            four (distinct) points are available or the spline fit fails.
        """
        if len(r) < 4:
            return v_obs

        try:
            unique_mask = np.concatenate(([True], np.diff(r) > 1e-10))
            r_unique = r[unique_mask]
            v_unique = v_obs[unique_mask]
            if len(r_unique) < 4:
                return v_obs

            spline = UnivariateSpline(r_unique, v_unique, s=len(v_unique)*5)
            return spline(r)
        except Exception:
            # Best effort: keep the raw curve on any fitting error, but no
            # longer swallow KeyboardInterrupt/SystemExit as a bare except did.
            return v_obs

    def correct_complex_galaxy(self, v_obs, r):
        """Smooth a curve with a blend of two Gaussian filters.

        The result is ``0.6 * G(sigma1) + 0.4 * G(sigma2)`` with
        ``sigma1 = max(0.5, n / 15)`` and ``sigma2 = max(1.0, n / 10)``,
        where ``n`` is the number of samples. The input is converted to
        float first because ``gaussian_filter1d`` keeps the input dtype and
        would truncate integer velocities.

        Args:
            v_obs: Velocities ordered by radius.
            r: Radii (unused; kept for a uniform smoother signature).

        Returns:
            The blended, smoothed velocities, or ``v_obs`` itself when fewer
            than four points are given or filtering fails.
        """
        if len(v_obs) < 4:
            return v_obs

        try:
            values = np.asarray(v_obs, dtype=float)
            sigma1 = max(0.5, len(values) / 15)
            sigma2 = max(1.0, len(values) / 10)
            smooth1 = gaussian_filter1d(values, sigma=sigma1)
            smooth2 = gaussian_filter1d(values, sigma=sigma2)
            return 0.6 * smooth1 + 0.4 * smooth2
        except (RuntimeError, TypeError, ValueError):
            return v_obs

    def correct_normal_galaxy(self, v_obs, r):
        """Smooth a curve with a single Gaussian filter.

        The kernel width is ``sigma = max(0.8, n / 12)`` for ``n`` samples.
        The input is converted to float first because ``gaussian_filter1d``
        keeps the input dtype and would truncate integer velocities.

        Args:
            v_obs: Velocities ordered by radius.
            r: Radii (unused; kept for a uniform smoother signature).

        Returns:
            The smoothed velocities, or ``v_obs`` itself when fewer than four
            points are given or filtering fails.
        """
        if len(v_obs) < 4:
            return v_obs

        try:
            values = np.asarray(v_obs, dtype=float)
            sigma = max(0.8, len(values) / 12)
            return gaussian_filter1d(values, sigma=sigma)
        except (RuntimeError, TypeError, ValueError):
            return v_obs

    def analyze_complete_database(self):
        """Run the fit / smooth / refit comparison for every galaxy.

        Steps per galaxy: sort by radius, compute the reduced χ² of the
        observed curve (``calculate_robust_chi2``), smooth the curve with the
        routine for its cluster (``intelligent_correction``), and compute the
        reduced χ² of the smoothed curve against the same errors.

        Each result row holds ``galaxy``, ``cluster``, ``data_points``,
        ``r_max``, ``v_max``, ``original_chi2``, ``corrected_chi2``,
        ``improvement_ratio`` (original / corrected, or 1.0 when the
        corrected value is not positive), ``success`` (ratio > 1 and
        corrected χ² < 50) and ``flatness`` (last / max velocity, or 0.8 for
        curves with five points or fewer). Galaxies whose analysis raises
        are recorded with ``cluster = -1`` and χ² values of 999.0.

        NOTE(review): the "corrected" curve is a smoothed copy of the same
        observations, so ``improvement_ratio`` measures how much better the
        parametric model fits smoothed data, not independent predictive
        skill. Confirm this is the intended metric.

        Returns:
            ``self.results``, rebuilt from scratch on every call.
        """
        print("\n🔍 Analyzing COMPLETE SPARC database (175 galaxies)...")
        print("=" * 80)

        self.build_comprehensive_manifold()

        # Start clean so calling this method twice does not duplicate rows.
        self.results = []

        galaxy_names = list(self.galaxy_data.keys())
        total_galaxies = len(galaxy_names)

        print(f"📈 Processing {total_galaxies} galaxies...")

        successful_analyses = 0

        for i, name in enumerate(galaxy_names):
            data = self.galaxy_data[name]

            if i % 25 == 0:
                print(f"   🚀 Progress: {i+1}/{total_galaxies} galaxies...")

            try:
                sort_idx = np.argsort(data['r'])
                r_sorted = data['r'][sort_idx]
                v_obs_sorted = data['v_obs'][sort_idx]
                v_err_sorted = data['v_err'][sort_idx]

                original_chi2, _ = self.calculate_robust_chi2(
                    v_obs_sorted, r_sorted, v_err_sorted
                )

                # Labels are looked up by position in galaxy_data; default to
                # cluster 0 when no clustering is available.
                cluster_id = 0
                if self.galaxy_clusters is not None and i < len(self.galaxy_clusters):
                    cluster_id = self.galaxy_clusters[i]

                corrected_v = self.intelligent_correction(
                    v_obs_sorted, r_sorted, v_err_sorted, cluster_id
                )

                corrected_chi2, _ = self.calculate_robust_chi2(
                    corrected_v, r_sorted, v_err_sorted
                )

                if corrected_chi2 > 0:
                    improvement = original_chi2 / corrected_chi2
                else:
                    improvement = 1.0

                success = bool(improvement > 1.0 and corrected_chi2 < 50)

                if len(v_obs_sorted) > 5 and np.max(v_obs_sorted) > 0:
                    flatness = v_obs_sorted[-1] / np.max(v_obs_sorted)
                else:
                    flatness = 0.8

                self.results.append({
                    'galaxy': name,
                    'cluster': cluster_id,
                    'data_points': len(r_sorted),
                    'r_max': np.max(r_sorted),
                    'v_max': np.max(v_obs_sorted),
                    'original_chi2': original_chi2,
                    'corrected_chi2': corrected_chi2,
                    'improvement_ratio': improvement,
                    'success': success,
                    'flatness': flatness
                })

                successful_analyses += 1

            except Exception as e:
                print(f"   ❌ Analysis failed for {name}: {e}")
                # Guard the maxima: np.max raises on empty arrays, which would
                # turn one bad galaxy into a crash of the whole run.
                self.results.append({
                    'galaxy': name,
                    'cluster': -1,
                    'data_points': len(data['r']),
                    'r_max': np.max(data['r']) if len(data['r']) else np.nan,
                    'v_max': np.max(data['v_obs']) if len(data['v_obs']) else np.nan,
                    'original_chi2': 999.0,
                    'corrected_chi2': 999.0,
                    'improvement_ratio': 1.0,
                    'success': False,
                    'flatness': 0.8
                })
                continue

        print(f"✅ Successfully analyzed {successful_analyses}/{total_galaxies} galaxies")
        return self.results

# =============================================================================
# COMPREHENSIVE REPORT FOR COMPLETE DATABASE
# =============================================================================

def create_complete_database_report(results, model):
    """Print a summary of the analysis results and save them as CSV.

    Rows with ``original_chi2 >= 900`` (the placeholder written for failed
    analyses) are excluded from the statistics but still written to the CSV.
    The improvement distribution covers every valid row: ratios of exactly
    1.0 are counted under "Needs Work" so the buckets add up to 100%.

    Args:
        results: List of per-galaxy result dicts as produced by
            ``CompleteDatabaseCosmicModel.analyze_complete_database``. An
            empty list is accepted and yields a summary with zero counts.
        model: Accepted for interface compatibility; not used.

    Side effects:
        Writes ``complete_sparc_database_analysis.csv`` to the current
        working directory and prints the report to stdout.
    """
    df = pd.DataFrame(results)

    # An empty result list produces a frame without columns; treat that as
    # "no valid rows" instead of failing on the missing column.
    if 'original_chi2' in df.columns:
        valid_results = df[df['original_chi2'] < 900]
    else:
        valid_results = df.iloc[0:0]

    print("\n" + "=" * 90)
    print("📊 COMPREHENSIVE ANALYSIS REPORT - COMPLETE SPARC DATABASE (175 GALAXIES)")
    print("=" * 90)

    print(f"\n🎯 EXECUTIVE SUMMARY:")
    print(f"   • Total Galaxies Processed: {len(df)}")
    print(f"   • Successfully Analyzed: {len(valid_results)}")

    if len(valid_results) > 0:
        ratio = valid_results['improvement_ratio']

        success_rate = valid_results['success'].mean() * 100
        avg_improvement = ratio.mean()
        print(f"   • Overall Success Rate: {success_rate:.1f}%")
        print(f"   • Average Improvement: {avg_improvement:.2f}x")
        print(f"   • Best Improvement: {ratio.max():.2f}x")
        print(f"   • χ² Improvement: {valid_results['original_chi2'].mean():.3f} → {valid_results['corrected_chi2'].mean():.3f}")

        print(f"\n🌌 PERFORMANCE BY GALAXY CLUSTER:")
        for cluster_id in sorted(valid_results['cluster'].unique()):
            if cluster_id < 0:
                # Negative labels mark galaxies without a cluster assignment.
                continue
            cluster_data = valid_results[valid_results['cluster'] == cluster_id]
            if len(cluster_data) > 0:
                cluster_success = cluster_data['success'].sum()
                cluster_total = len(cluster_data)
                success_rate = (cluster_success / cluster_total) * 100
                avg_imp = cluster_data['improvement_ratio'].mean()
                print(f"   Cluster {cluster_id}: {cluster_total} galaxies")
                print(f"     Success: {cluster_success}/{cluster_total} ({success_rate:.1f}%)")
                print(f"     Avg Improvement: {avg_imp:.2f}x")

        if len(valid_results) >= 15:
            print(f"\n🏆 TOP 15 IMPROVEMENTS:")
            top_15 = valid_results.nlargest(15, 'improvement_ratio')
            for _, row in top_15.iterrows():
                print(f"   • {row['galaxy']}: {row['improvement_ratio']:.2f}x (χ²: {row['original_chi2']:.3f} → {row['corrected_chi2']:.3f})")

        print(f"\n📊 IMPROVEMENT DISTRIBUTION:")
        categories = {
            'Exceptional (>10x)': int((ratio > 10).sum()),
            'Excellent (5-10x)': int(((ratio > 5) & (ratio <= 10)).sum()),
            'Good (2-5x)': int(((ratio > 2) & (ratio <= 5)).sum()),
            'Moderate (1-2x)': int(((ratio > 1) & (ratio <= 2)).sum()),
            # Includes ratio == 1.0 (no change), which previously fell
            # between the "Moderate" and "Needs Work" buckets.
            'Needs Work (≤1x)': int((ratio <= 1).sum())
        }

        for category, count in categories.items():
            percentage = (count / len(valid_results)) * 100
            print(f"   {category}: {count} galaxies ({percentage:.1f}%)")

    df.to_csv('complete_sparc_database_analysis.csv', index=False)
    print(f"\n💾 Full results saved to 'complete_sparc_database_analysis.csv'")

# =============================================================================
# MAIN EXECUTION FOR COMPLETE DATABASE
# =============================================================================

if __name__ == "__main__":
    # Script entry point: load the archive, analyse it, print and save the report.
    print("🚀 STARTING COMPLETE SPARC DATABASE ANALYSIS (175 GALAXIES)")
    print("=" * 60)

    # Stage 1: read the rotation curves from the default archive path.
    loader = CompleteSPARCLoader()
    galaxy_data = loader.load_all_galaxies()

    if not galaxy_data:
        # None (archive missing or unreadable) or an empty dict: nothing to analyse.
        print("❌ ANALYSIS FAILED: No galaxy data loaded")
    else:
        # Stage 2: run the per-galaxy analysis.
        model = CompleteDatabaseCosmicModel(galaxy_data)
        results = model.analyze_complete_database()

        # Stage 3: print the summary and write the CSV.
        create_complete_database_report(results, model)

        print("\n✅ COMPLETE ANALYSIS FINISHED SUCCESSFULLY!")
        print("📁 Results saved to: 'complete_sparc_database_analysis.csv'")